# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage


ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"

def transfg_config(dtype='ViT-B_16'):
    """transfg模型参数

    Args:
        dtype (str): ViT-B_16|ViT-B_32|ViT-L_16|ViT-L_32|ViT-H_14, default:ViT-B_16

    Returns:
        dict: 返回ViT模型输入参数
    """
    base = {"hidden_size": 768}
    embedding = {
        "patch_size": (16, 16),
        "split": 'non-overlap',
        "slide_step": 12,
        "embedding_dropout_rate": 0.1,
    }

    transformer = {
        "mlp_dim": 3072,
        "num_heads": 12,
        "num_layers": 12,
        "attention_dropout_rate": 0.0,
        "mlp_dropout_rate": 0.1
    }

    if dtype == 'ViT-B_32':
        embedding['patch_size'] = (32, 32)
    elif dtype == 'ViT-L_16':
        base['hidden_size'] = 1024
        transformer['mlp_dim'] = 4096 
        transformer['num_heads'] = 16
        transformer['num_layers'] = 24
    elif dtype == 'ViT-L_32':
        embedding['patch_size'] = (32, 32)
        base['hidden_size'] = 1024
        transformer['mlp_dim'] = 4096 
        transformer['num_heads'] = 16
        transformer['num_layers'] = 24
    elif dtype == 'ViT-H_14':
        embedding['patch_size'] = (14, 14)
        base['hidden_size'] = 1280
        transformer['mlp_dim'] = 5120 
        transformer['num_heads'] = 16
        transformer['num_layers'] = 32 

    return dict(base, **embedding, **transformer)
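

# A minimal usage sketch (not part of the original API): transfg_config merges the
# base/embedding/transformer dicts into one flat kwargs dict, so variant overrides
# such as ViT-L_16 show up directly in the result.
def _example_transfg_config():
    cfg = transfg_config('ViT-L_16')
    # ViT-L_16 overrides: hidden_size 1024, mlp_dim 4096, 16 heads, 24 layers.
    assert cfg['hidden_size'] == 1024 and cfg['num_layers'] == 24
    assert cfg['patch_size'] == (16, 16)  # unchanged from the ViT-B_16 default
    return cfg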


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)

def swish(x):
    return x * torch.sigmoid(x)

ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}

class LabelSmoothing(nn.Module):
    """
    NLL loss with label smoothing.
    """
    def __init__(self, smoothing=0.0):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)

        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
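

# A minimal sketch of what LabelSmoothing computes: per sample,
# loss = (1 - smoothing) * NLL(target) + smoothing * mean(-log p), averaged over
# the batch. The `_demo_*` helper name is illustrative, not from the original code.
def _demo_label_smoothing():
    torch.manual_seed(0)
    logits = torch.randn(4, 10)           # (batch, num_classes)
    targets = torch.randint(0, 10, (4,))  # integer class labels
    criterion = LabelSmoothing(smoothing=0.1)
    loss = criterion(logits, targets)     # scalar; smoothing=0.0 reduces to plain NLL
    return loss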

class Attention(nn.Module):
    def __init__(self, num_heads, hidden_size, attention_dropout_rate):
        """自注意力模块

        Args:
            num_heads (int): 注意力头的数量
            hidden_size (int): 隐藏特征维度
            attention_dropout_rate (float): 注意力dropout
        """
        super(Attention, self).__init__()
        self.num_attention_heads = num_heads
        self.attention_head_size = int(hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(hidden_size, self.all_head_size)
        self.key = Linear(hidden_size, self.all_head_size)
        self.value = Linear(hidden_size, self.all_head_size)

        self.out = Linear(hidden_size, hidden_size)
        self.attn_dropout = Dropout(attention_dropout_rate)
        self.proj_dropout = Dropout(attention_dropout_rate)

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, mask=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
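

# A minimal shape sketch for Attention: it keeps the (batch, tokens, hidden) layout
# and also returns the pre-dropout attention weights, which TransFG later uses for
# part selection. Sizes below assume ViT-B_16 on 224x224 inputs (196 patches + cls).
def _demo_attention_shapes():
    attn = Attention(num_heads=12, hidden_size=768, attention_dropout_rate=0.0)
    x = torch.randn(2, 197, 768)
    out, weights = attn(x)
    assert out.shape == (2, 197, 768)
    assert weights.shape == (2, 12, 197, 197)  # (batch, heads, tokens, tokens)
    return out, weights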

class Mlp(nn.Module):
    def __init__(self, hidden_size, mlp_dim, dropout_rate):
        """ 多层感知机

        Args:
            hidden_size (int): 隐藏特征维度
            mlp_dim (int): 多层感知机维度
            dropout_rate (float): 感知机dropout rate
        """
        super(Mlp, self).__init__()
        self.fc1 = Linear(hidden_size, mlp_dim)
        self.fc2 = Linear(mlp_dim, hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(p=dropout_rate)

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """
    def __init__(self, img_size, patch_size, hidden_size, slide_step, dropout_rate, split='non-overlap', in_channels=3):
        """图片Embedding

        Args:
            img_size (int， tuple): 图片尺寸
            patch_size (int, tuple): Patch的尺寸
            hidden_size (int): 隐藏特征维度
            slide_step (int): 滑动步长，如果使用split:"overlap"模式
            dropout_rate (float): Embedding的dropout rate
            split (str, optional): 两种patch分割模式，'non-overlap'和'overlap'. Defaults to 'non-overlap'.
            in_channels (int, optional): 输入图像的通道数. Defaults to 3.
        """
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)

        patch_size = _pair(patch_size)
        if split == 'non-overlap':
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.patch_embeddings = Conv2d(in_channels=in_channels,
                                           out_channels=hidden_size,
                                           kernel_size=patch_size,
                                           stride=patch_size)
        elif split == 'overlap':
            n_patches = ((img_size[0] - patch_size[0]) // slide_step + 1) * ((img_size[1] - patch_size[1]) // slide_step + 1)
            self.patch_embeddings = Conv2d(in_channels=in_channels,
                                           out_channels=hidden_size,
                                           kernel_size=patch_size,
                                           stride=(slide_step, slide_step))
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))

        self.dropout = Dropout(dropout_rate)

    def forward(self, x):
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)

        if self.hybrid:
            x = self.hybrid_model(x)
        x = self.patch_embeddings(x)
        x = x.flatten(2)
        x = x.transpose(-1, -2)
        x = torch.cat((cls_tokens, x), dim=1)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings
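

# A minimal sketch of the patch-count arithmetic: 'non-overlap' yields
# (H//P) * (W//P) patches, while 'overlap' slides the P-sized kernel with stride
# slide_step, giving ((H-P)//step + 1)**2 patches for square inputs; both modes
# prepend one cls token. The sizes below are illustrative assumptions.
def _demo_embeddings_token_count():
    emb = Embeddings(img_size=224, patch_size=16, hidden_size=768,
                     slide_step=12, dropout_rate=0.0, split='overlap')
    x = torch.randn(1, 3, 224, 224)
    tokens = emb(x)
    n_patches = ((224 - 16) // 12 + 1) ** 2  # 18 * 18 = 324
    assert tokens.shape == (1, n_patches + 1, 768)
    return tokens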

class Block(nn.Module):
    def __init__(self, hidden_size, mlp_dim, num_heads, mlp_dropout_rate, attention_dropout_rate):
        """Self-Attention模块

        Args:
            hidden_size (int): 隐藏特征维度
            mlp_dim (int): 多层感知机维度
            num_heads (int): 多头注意力数量
            mlp_dropout_rate (float): 多层感知机dropout rate
            attention_dropout_rate (float): 注意力dropout rate
        """
        super(Block, self).__init__()
        self.hidden_size = hidden_size
        self.attention_norm = LayerNorm(hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(hidden_size, eps=1e-6)
        self.ffn = Mlp(hidden_size=hidden_size, mlp_dim=mlp_dim, dropout_rate=mlp_dropout_rate)
        self.attn = Attention(num_heads=num_heads, hidden_size=hidden_size, attention_dropout_rate=attention_dropout_rate)

    def forward(self, x, mask=None):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x, mask)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
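

# A minimal sketch for the Block class above: it is a pre-norm transformer layer
# (LayerNorm before each sub-module, residual add after), so it preserves the
# (batch, tokens, hidden) shape and passes the attention weights through.
def _demo_block():
    block = Block(hidden_size=768, mlp_dim=3072, num_heads=12,
                  mlp_dropout_rate=0.0, attention_dropout_rate=0.0)
    x = torch.randn(2, 197, 768)
    y, weights = block(x)
    assert y.shape == x.shape and weights.shape == (2, 12, 197, 197)
    return y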

class Part_Attention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        """根据cls_token与其他token的关系，选择最相关的token

        Args:
            x (list): 输入n-1个block的Attention weights列表

        Returns:
            torch.Tensor, torch.Tensor: cls_token与其他token的关系值，多头注意力中每一个头中
        """
        length = len(x)
        last_map = x[0]
        for i in range(1, length):
            last_map = torch.matmul(x[i], last_map)
        last_map = last_map[:,:,0,1:]

        _, max_inx = last_map.max(2)

        return last_map, max_inx
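

# A minimal sketch of the part-selection rule: multiply the per-layer attention
# maps together (an attention-rollout-style product), read off row 0 (cls-to-patch
# attention, dropping the cls column), and take the argmax patch index per head.
def _demo_part_attention():
    torch.manual_seed(0)
    maps = [torch.softmax(torch.randn(2, 12, 197, 197), dim=-1) for _ in range(11)]
    last_map, max_inx = Part_Attention()(maps)
    assert last_map.shape == (2, 12, 196)  # cls-to-patch scores per head
    assert max_inx.shape == (2, 12)        # one selected patch index per head
    return max_inx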

class Encoder(nn.Module):
    """编码器: tokens to tokens

        Args:
            num_layers (int): Self-Attention的数量
            hidden_size (int): 隐藏特征维度
            mlp_dim (int): 多层感知维度
            num_heads (int): 多头注意力数量
            mlp_dropout_rate (float): 多层感知dropout rate
            attention_dropout_rate (float): 注意力dropout rate
    """
    def __init__(self, num_layers, hidden_size, mlp_dim, num_heads, mlp_dropout_rate, attention_dropout_rate):
        super(Encoder, self).__init__()
        self.layer = nn.ModuleList()

        block_params = {
            "hidden_size": hidden_size,
            "mlp_dim": mlp_dim,
            "num_heads": num_heads,
            "mlp_dropout_rate": mlp_dropout_rate,
            "attention_dropout_rate": attention_dropout_rate
        }
        for _ in range(num_layers - 1):
            layer = Block(**block_params)
            self.layer.append(copy.deepcopy(layer))
        self.part_select = Part_Attention()
        self.part_layer = Block(**block_params)
        self.part_norm = LayerNorm(hidden_size, eps=1e-6)

    def forward(self, hidden_states, mask=None):
        attn_weights = []
        for layer in self.layer:
            hidden_states, weights = layer(hidden_states, mask)
            attn_weights.append(weights)            
        part_wgts, part_inx = self.part_select(attn_weights)
        part_inx = part_inx + 1
        parts = []
        B = part_inx.size(0)
        for i in range(B):
            parts.append(hidden_states[i, part_inx[i,:]])
        parts = torch.stack(parts).squeeze(1)
        concat = torch.cat((hidden_states[:,0].unsqueeze(1), parts), dim=1)
        part_states, _ = self.part_layer(concat)
        part_encoded = self.part_norm(part_states) 

        return {"tokens": part_encoded, "connect": part_wgts, "select": part_inx}


class Transformer(nn.Module):
    """ViT Transformer模块 (tokens to tokens)

        Args:
            img_size (int,tuple): 输入图片尺寸
            patch_size (int,tuple): 输入patch大小尺寸
            hidden_size (int): 隐藏特征维度
            split (str): patch切分方式, 可选'non-overlap' 或 'overlap'
            slide_step (int): 配合split使用
            num_layers (int): Self-Attention的数量
            mlp_dim (int): 多层感知维度
            num_heads (int): 多头注意力数量
            embedding_dropout_rate (float): Embeding的dropout rate
            mlp_dropout_rate (float): 多层感知dropout rate
            attention_dropout_rate (float): 注意力dropout rate
            in_channels (int, optional): _description_. Defaults to 3.
    """
    def __init__(self, img_size, patch_size, hidden_size, split, slide_step, num_layers, mlp_dim, num_heads, embedding_dropout_rate, mlp_dropout_rate, attention_dropout_rate, in_channels=3):
        super(Transformer, self).__init__()
        embedding_params = {
            "img_size": img_size,
            "patch_size": patch_size,
            "split": split,
            "slide_step": slide_step,
            "hidden_size": hidden_size,
            "dropout_rate": embedding_dropout_rate,
            "in_channels": in_channels
        }

        encoder_params = {
            "num_layers": num_layers,
            "hidden_size": hidden_size,
            "mlp_dim": mlp_dim,
            "num_heads": num_heads,
            "mlp_dropout_rate": mlp_dropout_rate,
            "attention_dropout_rate": attention_dropout_rate
        }

        self.embeddings = Embeddings(**embedding_params)
        self.encoder = Encoder(**encoder_params)

    def forward(self, input_ids, mask=None):
        embedding_output = self.embeddings(input_ids)
        part_encoded_dict = self.encoder(embedding_output, mask)
        return part_encoded_dict

class TransFG(nn.Module):
    def __init__(self, img_size, num_classes, patch_size, hidden_size, split, slide_step, num_layers, mlp_dim, num_heads, embedding_dropout_rate, mlp_dropout_rate, attention_dropout_rate, in_channels=3, **kwargs):
        """TransFG模型: image to tokens

        Args:
            img_size (int,tuple): 输入图片尺寸
            num_classes (int): 类别的数量
            patch_size (int,tuple): 输入patch大小尺寸
            hidden_size (int): 隐藏特征维度
            split (str): patch切分方式, 可选'non-overlap' 或 'overlap'
            slide_step (int): 配合split使用
            num_layers (int): Self-Attention的数量
            mlp_dim (int): 多层感知维度
            num_heads (int): 多头注意力数量
            embedding_dropout_rate (float): Embeding的dropout rate
            mlp_dropout_rate (float): 多层感知dropout rate
            attention_dropout_rate (float): 注意力dropout rate
            in_channels (int, optional): _description_. Defaults to 3.
        """
        super(TransFG, self).__init__()
        tparams = {
            "img_size": img_size,
            "patch_size": patch_size,
            "split": split,
            "slide_step": slide_step,
            "hidden_size": hidden_size,
            "in_channels": in_channels,
            "num_layers": num_layers,
            "mlp_dim": mlp_dim,
            "num_heads": num_heads,
            "mlp_dropout_rate": mlp_dropout_rate,
            "embedding_dropout_rate": embedding_dropout_rate,
            "attention_dropout_rate": attention_dropout_rate

        }
        self.num_classes = num_classes
        self.classifier = "token" #config.classifier
        self.transformer = Transformer(**tparams)
        self.part_head = Linear(hidden_size, num_classes)
    
    def forward(self, x, mask=None):
        part_tokens_dict = self.transformer(x, mask)
        part_tokens = part_tokens_dict['tokens']
        part_logits = self.part_head(part_tokens[:, 0])

        return {"logits": part_logits, "tokens": part_tokens, "connect": part_tokens_dict['connect'], "select": part_tokens_dict['select']}
        return part_logits, part_tokens

        if labels is not None:
            if self.smoothing_value == 0:
                loss_fct = CrossEntropyLoss()
            else:
                loss_fct = LabelSmoothing(self.smoothing_value)
            part_loss = loss_fct(part_logits.view(-1, self.num_classes), labels.view(-1))
            contrast_loss = con_loss(part_tokens[:, 0], labels.view(-1))
            loss = part_loss + contrast_loss
            return loss, part_logits
        else:
            return part_logits

    def load_from(self, weights):
        with torch.no_grad():
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
            self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
            self.transformer.encoder.part_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.part_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                ntok_new = posemb_new.size(1)

                if self.classifier == "token":
                    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
                    ntok_new -= 1
                else:
                    posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            for bname, block in self.transformer.encoder.named_children():
                if not bname.startswith('part'):
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
                gn_weight = np2th(weights["gn_root/scale"]).view(-1)
                gn_bias = np2th(weights["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname) 
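

# A hedged end-to-end training-loss sketch, recovering the intent of the
# unreachable code removed from TransFG.forward: cross-entropy (optionally
# label-smoothed) on the part logits plus a contrastive loss on the cls token.
# `con_loss` is not defined in this file and is assumed to live in the original
# TransFG training code, so it stays commented out here.
def _transfg_training_loss_sketch(smoothing_value=0.1):
    cfg = transfg_config('ViT-B_16')
    model = TransFG(img_size=448, num_classes=200, **cfg)
    x = torch.randn(2, 3, 448, 448)
    labels = torch.randint(0, 200, (2,))
    out = model(x)
    part_logits = out["logits"]
    if smoothing_value == 0:
        loss_fct = CrossEntropyLoss()
    else:
        loss_fct = LabelSmoothing(smoothing_value)
    part_loss = loss_fct(part_logits.view(-1, model.num_classes), labels.view(-1))
    # contrast_loss = con_loss(out["tokens"][:, 0], labels.view(-1))  # external helper
    return part_loss  # the original adds contrast_loss before backprop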